// STL
#include <cmath>                        // std::sqrt()
#include <complex>                      // std::complex<>, ::real(), ::imag()
#include <stdexcept>                    // std::runtime_error()
#include <string>                       // std::to_string()
#include <tuple>                        // std::tie()
#include <vector>                       // std::vector<>

// jScience
#include "jScience/linalg.hpp"          // Matrix<>, *, eigen(), transpose()
#include "jScience/stl/MultiVector.hpp" // MultiVector<>
#include "jutility.hpp"                 // sort_indices_desc()


//
// DESC: Determine the whitening matrix according to the ZCA-cor transformation (correlation-based zero-phase component analysis).
//
// -----
//
// INPUT:
//
//     Sigma :: covariance matrix
//
// -----
//
// NOTE: The covariance matrix is a real, symmetric matrix.
//
// -----
//
// NOTE: No error check(s) are made on the input arguments.
//
// =====
//
// NOTE: This whitening is the unique procedure that ensures that the components of the whitened vector remain maximally correlated with the corresponding components of the original variables.
//
// -----
//
// NOTE: See:
//
//     A. Kessy, A. Lewin, and K. Strimmer, "Optimal Whitening and Decorrelation" (2016)
//
inline Matrix<double> whitening::ZCAcor(
                                        const Matrix<double> &Sigma
                                        )
{
    //
    // RETURNS: W = corr^(-1/2) * V^(-1/2), where V holds the diagonal of Sigma and corr = V^(-1/2) * Sigma * V^(-1/2).
    //
    // THROWS: std::runtime_error if
    //
    //     - eigen() reports INFO != 0,
    //     - an eigenvalue or (right) eigenvector component of corr has a nonzero imaginary part,
    //     - an eigenvalue of corr is not strictly positive (corr^(-1/2) then does not exist), or
    //     - the eigenvectors of corr are linearly dependent (they cannot be orthonormalized).
    //

    //=========================================================
    // INITIALIZATION
    //=========================================================

    const int N = Sigma.size(0);

    //=========================================================
    // CALCULATE (ADDITIONAL) MATRICES (FROM Sigma)
    //=========================================================

    // INVERSE SQUARE-ROOT OF THE VARIANCE MATRIX
    //
    // NOTE: The positive square-roots are taken (giving inverse standard deviations).
    //
    Matrix<double> V_invsqrt(N,N, 0.);

    for( int i = 0; i < N; ++i )
    {
        V_invsqrt(i,i) = 1./std::sqrt(Sigma(i,i));
    }

    // CORRELATION MATRIX
    //
    // NOTE: For a discussion related to the following calculation, see:
    //
    //     https://en.wikipedia.org/wiki/Covariance_matrix#Relation_to_the_matrix_of_correlation_coefficients
    //
    // NOTE: This matrix is also a real, symmetric matrix (up to floating-point rounding).
    //
    Matrix<double> corr = V_invsqrt*Sigma*V_invsqrt;

    //=========================================================
    // EIGENDECOMPOSITION (OF corr)
    //=========================================================

    // NOTE: Because corr is real and symmetric, its eigendecomposition is
    //
    //     corr = Q*Lambda*Q^-1 = Q*Lambda*Q^T
    //
    //     where:
    //
    //     Q is an orthogonal matrix whose columns are the eigenvectors of corr
    //     Lambda is a diagonal matrix whose entries are the eigenvalues of corr
    //
    // NOTE: See:
    //
    //     https://en.wikipedia.org/wiki/Eigendecomposition_of_a_matrix#Real_symmetric_matrices

    // NOTE:
    //
    //          int                               INFO   :: DGEEV exit code
    //          std::vector<std::complex<double>> evals  :: Eigenvalues
    //          MultiVector<std::complex<double>> VL     :: Left eigenvectors
    //          MultiVector<std::complex<double>> VR     :: Right eigenvectors
    //
    // NOTE: Names beginning with an underscore followed by an uppercase letter are reserved in C++, hence "evals" (not "_W").
    //
    int                               INFO;
    std::vector<std::complex<double>> evals;
    MultiVector<std::complex<double>> VL;
    MultiVector<std::complex<double>> VR;
    std::tie(
             INFO,
             evals,
             VL,
             VR
             ) = eigen(
                       corr
                       );

    // CHECK FOR COMPLETION
    if( INFO != 0 )
    {
        throw std::runtime_error(("error in whitening::ZCAcor(): eigen() returned INFO = " + std::to_string(INFO) + " != 0"));
    }

    // CHECK FOR (AND STORE AS) REAL, POSITIVE EIGENVALUES
    //
    // NOTE(review): The imaginary-part test is an exact comparison. For (numerically) repeated eigenvalues a general
    //               eigensolver may return a conjugate pair with tiny imaginary parts, which is rejected here --- consider
    //               a tolerance if this is hit in practice.
    //
    // NOTE: A non-positive eigenvalue means corr is not positive definite, so corr^(-1/2) does not exist; without this
    //       check the result would silently contain inf/NaN. The negated comparison also rejects NaN.
    //
    std::vector<double> evals_r;
    evals_r.reserve(N);

    for( int i = 0; i < N; ++i )
    {
        if( std::imag(evals[i]) != 0. )
        {
            throw std::runtime_error(("error in whitening::ZCAcor(): eigenvalue returned from eigen() is complex"));
        }

        const double lambda = std::real(evals[i]);

        if( !(lambda > 0.) )
        {
            throw std::runtime_error(("error in whitening::ZCAcor(): eigenvalue of the correlation matrix is not positive"));
        }

        evals_r.push_back(lambda);
    }

    // STORE SORTED EIGENVALUES (AND THE CORRESPONDING EIGENVECTORS)
    //
    // NOTE: Following the standard convention, the eigenvalues are sorted in order from largest to smallest.
    //
    Matrix<double> Lambda(N,N, 0.);
    Matrix<double> Q(N,N);

    int col = 0;

    for( const auto idx : sort_indices_desc(evals_r) )
    {
        Lambda(col,col) = evals_r[idx];

        for( int i = 0; i < N; ++i )
        {
            // CHECK FOR REAL EIGENVECTORS
            //
            // NOTE: The right eigenvectors [returned from eigen()] are (also) stored as column vectors.
            //
            if( std::imag(VR(i,idx)) != 0. )
            {
                throw std::runtime_error(("error in whitening::ZCAcor(): (right) eigenvector returned from eigen() is complex"));
            }

            Q(i,col) = std::real(VR(i,idx));
        }

        ++col;
    }

    // RE-ORTHONORMALIZE THE EIGENVECTORS (MODIFIED GRAM-SCHMIDT, IN SORTED COLUMN ORDER)
    //
    // NOTE: The inverse square-root below uses Q^T in place of Q^-1, which is only valid if Q is orthogonal. A general
    //       (nonsymmetric) eigensolver such as DGEEV does not guarantee mutually orthogonal eigenvectors for repeated
    //       eigenvalues. Eigenvectors of distinct eigenvalues of a symmetric matrix are already orthogonal, so each column
    //       is only mixed with columns of its own eigenspace and remains an eigenvector; this also enforces unit norm.
    //
    for( int k = 0; k < N; ++k )
    {
        for( int p = 0; p < k; ++p )
        {
            double proj = 0.;

            for( int i = 0; i < N; ++i )
            {
                proj += Q(i,p)*Q(i,k);
            }

            for( int i = 0; i < N; ++i )
            {
                Q(i,k) -= proj*Q(i,p);
            }
        }

        double norm = 0.;

        for( int i = 0; i < N; ++i )
        {
            norm += Q(i,k)*Q(i,k);
        }

        norm = std::sqrt(norm);

        if( !(norm > 0.) )
        {
            throw std::runtime_error(("error in whitening::ZCAcor(): eigenvectors returned from eigen() are linearly dependent"));
        }

        for( int i = 0; i < N; ++i )
        {
            Q(i,k) /= norm;
        }
    }

    //=========================================================
    // CALCULATE (ADDITIONAL) MATRICES
    //=========================================================

    // INVERSE SQUARE-ROOT OF THE EIGENVALUE MATRIX
    //
    // NOTE: The positive square-roots are (again) taken; all eigenvalues were checked to be strictly positive above.
    //
    Matrix<double> Lambda_invsqrt(N,N, 0.);

    for( int i = 0; i < N; ++i )
    {
        Lambda_invsqrt(i,i) = 1./std::sqrt(Lambda(i,i));
    }

    // INVERSE SQUARE-ROOT OF THE CORRELATION MATRIX
    //
    // NOTE: The following gives the unique (symmetric, positive-definite) inverse matrix square root of the correlation matrix.
    //
    // NOTE: [Subject (only) to taking the positive square-root of the eigenvalue matrix (just above).]
    //
    Matrix<double> corr_invsqrt = Q*Lambda_invsqrt*transpose(Q);

    //=========================================================
    // CALCULATE (AND RETURN) WHITENING MATRIX
    //=========================================================

    // NOTE: Eq. (11).
    //
    return corr_invsqrt*V_invsqrt;
}
